# -*- coding: utf-8 -*-
from warnings import warn

import numpy as np
import pandas as pd

from ..misc import NeuroKitWarning, find_closest
from ..stats import fit_polynomial
from .epochs_to_df import _df_to_epochs


def _eventrelated_sanitizeinput(epochs, what="ecg", silent=False):
    """Validate the `epochs` argument of an `*_eventrelated()` function.

    Parameters
    ----------
    epochs : dict or pd.DataFrame
        Either a dict of epochs (the duration check reads the `.index` of each value),
        or a dataframe containing a "Time" column, which is converted back to a dict
        through `_df_to_epochs()`.
    what : str
        Prefix used to name the calling function in the error and warning messages
        (e.g., "ecg" gives "ecg_eventrelated()").
    silent : bool
        The long-epoch warning is only considered when this is exactly `False`.

    Returns
    -------
    dict
        The epochs as a dict (the input itself when a dict was passed).

    Raises
    ------
    ValueError
        If `epochs` is a dataframe without a "Time" column, or if it is neither a
        dict nor a dataframe.

    """
    if isinstance(epochs, pd.DataFrame):
        if "Time" not in epochs.columns.values:
            raise ValueError(
                "It seems that you are trying to pass a single epoch to an eventrelated function. "
                + 'If this is what you want, please wrap it in a dict: `{"0": epoch}` '
            )
        # Assume it is a dataframe of epochs created by `epochs_to_df`
        epochs = _df_to_epochs(epochs)  # Convert df back to dict

    if not isinstance(epochs, dict):
        raise ValueError(
            "NeuroKit error: " + str(what) + "_eventrelated(): Please specify an input "
            "that is of the correct form i.e., either a dictionary "
            "or dataframe."
        )

    # Warning for long epochs. An empty dict is skipped: averaging an empty list
    # would only produce NaN together with a NumPy RuntimeWarning.
    if silent is False and epochs:
        durations = [np.max(epoch.index) - np.min(epoch.index) for epoch in epochs.values()]
        if np.mean(durations) > 10:
            warn(
                str(what) + "_eventrelated():"
                " The duration of your epochs seems quite long. You might want"
                " to use " + str(what) + "_intervalrelated().",
                category=NeuroKitWarning,
            )
    return epochs


def _eventrelated_addinfo(epoch, output=None):
    """Copy the descriptive columns of an epoch into an output dict.

    Parameters
    ----------
    epoch : pd.DataFrame
        A single epoch. Its index is used to locate the row closest to zero.
    output : dict, optional
        Dict to update in place. A new dict is created for each call when omitted, so
        that information from one epoch cannot leak into the result of another call.

    Returns
    -------
    dict
        `output`, with the following keys added when applicable:
        "Event_Onset" (value of the "Index" column at the row whose index is closest
        to zero) and "Label", "Condition", "Participant" (each only if the column
        exists and holds a single distinct value).

    """
    if output is None:
        output = {}

    # Add event onset
    if "Index" in epoch.columns:
        closest_to_zero = np.argmin(np.abs(epoch.index))
        output["Event_Onset"] = epoch.iloc[closest_to_zero]["Index"]

    # Add label, condition and participant_id (only when constant within the epoch)
    for column in ("Label", "Condition", "Participant"):
        if column in epoch.columns and len(set(epoch[column])) == 1:
            output[column] = epoch[column].values[0]

    return output


def _eventrelated_sanitizeoutput(data):
    """Turn a dict of per-epoch feature dicts into an ordered dataframe.

    Parameters
    ----------
    data : dict
        Mapping whose keys become the rows of the dataframe (it is passed to
        `pd.DataFrame.from_dict(..., orient="index")`).

    Returns
    -------
    pd.DataFrame
        One row per key of `data`. Rows are sorted by "Event_Onset" when that column
        exists, and the columns "Label", "Condition" and "Event_Onset" (those that
        are present) come first, in that order, followed by the remaining columns.

    """
    table = pd.DataFrame.from_dict(data, orient="index")  # Convert to a dataframe

    if "Event_Onset" in table.columns.values:
        table = table.sort_values("Event_Onset")

    # Columns to move to the front, in their final order
    leading = [name for name in ("Label", "Condition", "Event_Onset") if name in table.columns.values]
    if not leading:
        return table

    trailing = [name for name in table.columns if name not in leading]
    return table[leading + trailing]


def _eventrelated_rate(epoch, output=None, var="ECG_Rate"):
    """Compute rate-related features of an epoch relative to its baseline.

    Parameters
    ----------
    epoch : pd.DataFrame
        A single epoch. Its index is searched for the value closest to 0.
    output : dict, optional
        Dict to update in place. A new dict is created for each call when omitted, so
        that features from one epoch cannot leak into the result of another call.
    var : str
        Name of the column to analyze. It is also the prefix of all the output keys.

    Returns
    -------
    dict
        `output`, left unchanged (after a warning) if `epoch` has no `var` column.
        Otherwise the following keys are added, each prefixed by `var`:
        "_Baseline" (value at the sample closest to 0), "_Max", "_Min", "_Mean"
        (of the samples after the baseline sample, minus the baseline), "_SD"
        (of those samples, not baseline-corrected), "_Max_Time", "_Min_Time" (index
        values at the maximum and minimum) and "_Trend_Linear", "_Trend_Quadratic",
        "_Trend_R2" (from an order-2 `fit_polynomial()` of the baseline-corrected
        samples).

    """
    if output is None:
        output = {}

    # Sanitize input: the column is accessed by its exact name below, so an exact
    # membership test is needed (a partial name match would end in a KeyError).
    if var not in epoch.columns:
        warn(
            "Input does not have an `" + var + "` column."
            " Will skip all rate-related features.",
            category=NeuroKitWarning,
        )
        return output

    # Get baseline
    # NOTE(review): `find_closest(..., return_index=True)` is used as a positional
    # index into the epoch - confirm against the helper's definition.
    zero = find_closest(0, epoch.index.values, return_index=True)  # Find closest to 0
    baseline = epoch[var].iloc[zero]

    # Get signal (everything after the baseline sample)
    signal = epoch[var].values[zero + 1 :]
    index = epoch.index.values[zero + 1 :]

    # Max / Min / Mean
    output[var + "_Baseline"] = baseline
    output[var + "_Max"] = np.max(signal) - baseline
    output[var + "_Min"] = np.min(signal) - baseline
    output[var + "_Mean"] = np.mean(signal) - baseline
    output[var + "_SD"] = np.std(signal)

    # Time of Max / Min
    output[var + "_Max_Time"] = index[np.argmax(signal)]
    output[var + "_Min_Time"] = index[np.argmin(signal)]

    # Modelling
    # These are experimental indices corresponding to parameters of a quadratic model
    # Instead of raw values (such as min, max etc.)
    _, info = fit_polynomial(signal - baseline, index, order=2)
    output[var + "_Trend_Linear"] = info["coefs"][1]
    output[var + "_Trend_Quadratic"] = info["coefs"][2]
    output[var + "_Trend_R2"] = info["R2"]

    return output
